{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import numpy as np\n",
    "from shutil import copyfile\n",
    "symlink = True    # If this is false the files are copied instead\n",
    "combine_train_valid = False    # If this is true, the train and valid sets are ALSO combined"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imagenet constituent samples' extraction\n",
    "This notebook shows how to construct a dataset that has no CIFAR samples. This can be used for other tasks or for assessment of models trained on CIFAR, to understand how well these models deal with distribution shift. \n",
    "\n",
    "#### ENSURE THAT CINIC-10 IS DOWNLOADED AND STORED IN ../data/cinic-10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "cinic_directory = \"../data/cinic-10\"\n",
    "imagenet_directory = \"../data/cinic-10-imagenet\"\n",
    "classes = [\"airplane\", \"automobile\", \"bird\", \"cat\", \"deer\", \"dog\", \"frog\", \"horse\", \"ship\", \"truck\"]\n",
    "sets = ['train', 'valid', 'test']\n",
    "if not os.path.exists(imagenet_directory):\n",
    "    os.makedirs(imagenet_directory)\n",
    "if not os.path.exists(imagenet_directory + '/train'):\n",
    "    os.makedirs(imagenet_directory + '/train')\n",
    "if not os.path.exists(imagenet_directory + '/test'):\n",
    "    os.makedirs(imagenet_directory + '/test')\n",
    "    \n",
    "for c in classes:\n",
    "    if not os.path.exists('{}/train/{}'.format(imagenet_directory, c)):\n",
    "        os.makedirs('{}/train/{}'.format(imagenet_directory, c))\n",
    "    if not os.path.exists('{}/test/{}'.format(imagenet_directory, c)):\n",
    "        os.makedirs('{}/test/{}'.format(imagenet_directory, c))\n",
    "    if not combine_train_valid:\n",
    "        if not os.path.exists('{}/valid/{}'.format(imagenet_directory, c)):\n",
    "            os.makedirs('{}/valid/{}'.format(imagenet_directory, c))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "for s in sets:\n",
    "    for c in classes:\n",
    "        source_directory = '{}/{}/{}'.format(cinic_directory, s, c)\n",
    "        filenames = glob.glob('{}/*.png'.format(source_directory))\n",
    "        for fn in filenames:\n",
    "            dest_fn = fn.split('/')[-1]\n",
    "            if (s == 'train' or s == 'valid') and combine_train_valid and 'cifar' not in fn.split('/')[-1]:\n",
    "                dest_fn = '{}/train/{}/{}'.format(imagenet_directory, c, dest_fn)\n",
    "                if symlink:\n",
    "                    if not os.path.islink(dest_fn):\n",
    "                        os.symlink(fn, dest_fn)\n",
    "                else:\n",
    "                    copyfile(fn, dest_fn)\n",
    "                \n",
    "            elif (s == 'train') and 'cifar' not in fn.split('/')[-1]:\n",
    "                dest_fn = '{}/train/{}/{}'.format(imagenet_directory, c, dest_fn)\n",
    "                if symlink:\n",
    "                    if not os.path.islink(dest_fn):\n",
    "                        os.symlink(fn, dest_fn)\n",
    "                else:\n",
    "                    copyfile(fn, dest_fn)\n",
    "                    \n",
    "            elif (s == 'valid') and 'cifar' not in fn.split('/')[-1]:\n",
    "                dest_fn = '{}/valid/{}/{}'.format(imagenet_directory, c, dest_fn)\n",
    "                if symlink:\n",
    "                    if not os.path.islink(dest_fn):\n",
    "                        os.symlink(fn, dest_fn)\n",
    "                else:\n",
    "                    copyfile(fn, dest_fn)\n",
    "                    \n",
    "            elif s == 'test' and 'cifar' not in fn.split('/')[-1]:\n",
    "                dest_fn = '{}/test/{}/{}'.format(imagenet_directory, c, dest_fn)\n",
    "                if symlink:\n",
    "                    if not os.path.islink(dest_fn):\n",
    "                        os.symlink(fn, dest_fn)\n",
    "                else:\n",
    "                    copyfile(fn, dest_fn)\n",
    "                    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
